// fits Wakefield's hierarchical model for ecological inference using
// Wakefield's normal approximation to the binomial convolution likelihood.
// Covariates incorporated to implement King's extended model.
//
// Modified version of hierNA.cc from MCMCpack
//
// Updated 10/31/03 by Kevin Corder

#include <iostream> 
#include "Scythe_Matrix.h"
#include "Scythe_Simulate.h"
#include "Scythe_Stat.h"
#include "Scythe_Math.h"
#include "Scythe_LA.h"
#include "Scythe_IDE.h"

using namespace SCYTHE;
using namespace std;

// NOTE: the two regression updates below are copied from MCMCregress.cc.
// They were renamed (suffix "2") here so this file compiles correctly
// alongside the originals:
//   regress_beta_update2
//   regress_sigma2_update2



// Gibbs draw of the coefficient vector in a Gaussian linear regression with
// a multivariate Normal prior.
//
//   XpX    : the cross-product X'X, precomputed by the caller so it is not
//            rebuilt on every call
//   XpY    : the cross-product X'Y
//   b      : prior mean of beta
//   B      : prior precision of beta (it is added to X'X / sigma2)
//   sigma2 : current error variance
//
// Returns post_mean + cholesky(post_var) * z with z ~ N(0, I), where
//   post_var  = (B + X'X / sigma2)^{-1}
//   post_mean = post_var * (X'Y / sigma2 + B b)
Matrix <double> 
regress_beta_update2 (const Matrix <double> &XpX, const Matrix <double> &XpY,
		     const Matrix <double> &b, const Matrix <double> &B,
		     const double sigma2)
{
  const double inv_sigma2 = 1.0 / sigma2;

  Matrix <double> post_var = invpd (B + XpX * inv_sigma2);
  Matrix <double> root = cholesky (post_var);
  Matrix <double> post_mean = post_var * (XpY * inv_sigma2 + B * b);

  // one standard-normal draw per coefficient
  Matrix <double> z = rnorm (XpX.cols (), 1);

  return post_mean + root * z;
}

// linear regression with Gaussian errors sigma2 update (inverse-Gamma
// prior)

double
regress_sigma2_update2 (const Matrix <double> &X, const Matrix <double> &Y,
		       const Matrix <double> &beta, const double nu,
		       const double delta)
{

  Matrix <double> e = Y - X * beta;
  Matrix <double> SSE = crossprod (e);
  double nu2 = (nu + X.rows ()) * 0.5;
  double delta2 = (delta + SSE[0]) * 0.5;
  double sigma2_update = rigamma (nu2, delta2);

  return sigma2_update;
}



// DEFINE hierEX function

extern "C"{
 
void hierEXC(double* sample, const int* samrow, const int* samcol,
      const double* Rr0, const double* Rr1, const double* Rc0,
      const double* Rc1, const int* Rntables, const int* Rburnin,
      const int* Rmcmc, const int* Rthin, 
      const double* Rnu0, const double* Rdelta0,
      const double* Rnu1, const double* Rdelta1,
      const int* Rverbose, const double* Rtune, 
      const int* Rseed, int* accepts,
      const double *X, const int *Xrow, const int *Xcol,
      const double *b0start,
      const int *b0startrow, const int *b0startcol,
      const double *b1start,
      const int *b1startrow, const int *b1startcol,
      const double *b0, const int *b0row, const int *b0col,
      const double *B0, const int *B0row, const int *B0col,
      const double *b1, const int *b1row, const int *b1col,
      const double *B1, const int *B1row, const int *B1col,
      const double *sigma20start,
      const double *sigma21start          
      ){

// MCMC-related quantities
    
    int burnin = *Rburnin;
    int mcmc =   *Rmcmc;
    int thin =   *Rthin;
    int tot_iter = *Rburnin + *Rmcmc;
    double tune = *Rtune;

// put together matrices
// note: X matrix multiplied by sqrt(weight) in .r function

    int ntables = *Rntables;
    int verbose = *Rverbose;

    Matrix<double> r0(ntables, 1, Rr0);
    Matrix<double> r1(ntables, 1, Rr1);
    Matrix<double> c0(ntables, 1, Rc0);
    Matrix<double> c1(ntables, 1, Rc1);

    Matrix<double> N = c0 + c1;
    
    Matrix <double> Msample (samcol[0], samrow[0], sample);
    Msample = t (Msample);
    Matrix <double> MX (Xcol[0], Xrow[0], X);
    MX = t (MX);
    
    Matrix <double> Mbeta0start (b0startcol[0], b0startrow[0], b0start);
    Mbeta0start = t (Mbeta0start);
    Matrix <double> Mb0 (b0col[0], b0row[0], b0);
    Mb0 = t (Mb0);
    Matrix <double> MB0 (B0col[0], B0row[0], B0);
    MB0 = t (MB0);
    
    Matrix <double> Mbeta1start (b1startcol[0], b1startrow[0], b1start);
    Mbeta1start = t (Mbeta1start);
    Matrix <double> Mb1 (b1col[0], b1row[0], b1);
    Mb1 = t (Mb1);
    Matrix <double> MB1 (B1col[0], B1row[0], B0);
    MB1 = t (MB1);
    
// define constants and form cross-product matrices

    int k = MX.cols ();
    Matrix <double> XpX = crossprod (MX);
    

// load data
// table notation is:
// --------------------
//   Y0  |     | r0
// --------------------
//   Y1  |     | r1
// --------------------
//   c0  | c1  | N

  
   
    // initialize seed (mersenne twister / use default seed unless specified)

    if(*Rseed==0) set_mersenne_seed(5489UL);
    else set_mersenne_seed(*Rseed);

    
 
        
     
    // storage matrices

    Matrix<double> p0mat(mcmc/thin, ntables);
    Matrix<double> p1mat(mcmc/thin, ntables);
    Matrix<double> beta0m (mcmc/thin, k);
    Matrix<double> sigma20m (mcmc/thin, 1);
    Matrix<double> beta1m (mcmc/thin, k);
    Matrix<double> sigma21m (mcmc/thin, 1);
    int count = 0;

    
    // starting values

    Matrix<double> p0 = ones<double>(ntables,1)*0.5;
    Matrix<double> p1 = ones<double>(ntables,1)*0.5;
    Matrix<double> theta0 = log(p0/(1.0 - p0));
    Matrix<double> theta1 = log(p1/(1.0 - p1));
    
    // start with weighted=unweighted
 
    Matrix<double> wtheta0=theta0;
    Matrix<double> wtheta1=theta1;

    Matrix<double> XpY0 = t (MX) * wtheta0;
    Matrix<double> XpY1 = t (MX) * wtheta1;

    Matrix<double> beta0 = Mbeta0start;
    double sigma20 = sigma20start[0];
    Matrix<double> beta1 = Mbeta1start;
    double sigma21 = sigma21start[0];

    double nu0 = *Rnu0;
    double delta0= *Rdelta0;
    double nu1= *Rnu1;
    double delta1= *Rdelta1;
    
    Matrix <double> Xbeta0=MX*Mbeta0start;
    Matrix <double> Xbeta1=MX*Mbeta1start;
   
    Matrix<double> logjumpdens_cur = ones<double>(ntables,1)*1e20;
    
    // tomography line quantities
    // p0 on X axis and p1 on Y axis
    Matrix<double> r0frac = r0/(r0+r1);
    Matrix<double> r1frac = r1/(r0+r1);
    Matrix<double> c0frac = c0/(c0+c1);
    Matrix<double> c1frac = c1/(c0+c1);    
    Matrix<double> tfrac=c0/(r0+r1);
    Matrix<double> intercept = c0/r1;   // intercept of tomography line 
    Matrix<double> slope = -1*(r0/r1);  // slope of tomography line
    Matrix<double> orthoSD = tune/sqrt(N); // sd for MH sampline
    Matrix<double> orthoVar = pow(orthoSD, 2);
    Matrix<double> p0min(ntables,1);
    Matrix<double> p0max(ntables,1);
    
    // calculate min and max possible values of p0
    for (int i=0; i<ntables; ++i){
      p0min[i] = SCYTHE::max((1.0 - intercept[i])/slope[i], 0.0);
      p0max[i] = SCYTHE::min(-1*intercept[i]/slope[i], tfrac[i]);
    }
    
    // adjust p0min and p0max so all of [0,1]^2 is covered by the 
    // proposal density in the MH step
    for (int i=0; i<ntables; ++i){
      double run = ::sqrt(4.0 * ::pow(orthoSD[i], 2) * 
			  (1.0 + 1.0/::pow(slope[i], 2))) /
	(2.0*(1.0 + 1.0 / ::pow(slope[i], 2)));
      p0min[i] = p0min[i] - run;
      p0max[i] = p0max[i] + run;
    }
    
     
    for (int iter=0; iter<tot_iter; ++iter){
      
        
       for (int i=0; i<ntables; ++i){
	
	// sample (p0,p1)|r0,r1,c0,c1
	
	// sample candidate values of p0 and p1
	double u = runif()*(p0max[i]-p0min[i]) + p0min[i]; 
	double length = rnorm(0.0, orthoSD[i]);
	double s = sgn(length);
	length = fabs(length);
	double run = s * ::sqrt(4.0 * ::pow(length, 2) * 
				(1.0 + 1.0/::pow(slope[i], 2))) /
	  (2.0*(1.0 + 1.0 / ::pow(slope[i], 2)));
	double rise = -1.0 * run/slope[i];
	
       // the candidate values
	double p0_can = u + run;
	double p1_can = intercept[i] + slope[i]*(u) + rise;
	double theta0_can = ::log(p0_can/(1.0 - p0_can));
	double theta1_can = ::log(p1_can/(1.0 - p1_can));
        double wtheta0_can=theta0_can*sqrt(N[i]);
        double wtheta1_can=theta1_can*sqrt(N[i]);

        // log density ordinates
	double logjumpdens_can = lndnorm(length, 0.0, orthoSD[i]);
	double logprior_can, loglike_can, logpost_can;
	if ((p0_can < 1.0) && (p0_can > 0.0) && (p1_can < 1.0) && 
	    (p1_can > 0.0)){
	  logprior_can = lndnorm(wtheta0_can, Xbeta0[i], sqrt(sigma20)) + 
	    lndnorm(wtheta1_can, Xbeta1[i], sqrt(sigma21));
	  loglike_can = lndnorm(c0[i], r0[i]*p0_can + r1[i]*p1_can,
				sqrt(r0[i]*p0_can*(1.0-p0_can) + 
				     r1[i]*p1_can*(1.0-p1_can)));
	  logpost_can = loglike_can + logprior_can;
	}
	else{
	  logpost_can = ::log(0);
	}


       double logprior_cur = lndnorm(theta0[i], Xbeta0[i], sqrt(sigma20)) + 
	  lndnorm(theta1[i], Xbeta1[i], sqrt(sigma21));

// End updates
	
	double loglike_cur = lndnorm(c0[i], r0[i]*p0[i] + r1[i]*p1[i],
				     sqrt(r0[i]*p0[i]*(1.0-p0[i]) + 
					  r1[i]*p1[i]*(1.0-p1[i])));
	
	double logpost_cur = loglike_cur + logprior_cur;
	
	
	double alpha = ::exp(logpost_can - logpost_cur + logjumpdens_cur[i] - 
			     logjumpdens_can);
	
	if (runif() < alpha){
	  p0[i] = p0_can;
	  p1[i] = p1_can;
	  theta0[i] = theta0_can;
	  theta1[i] = theta1_can;
          wtheta0[i] = theta0[i]*sqrt(N[i]);
          wtheta1[i] = theta1[i]*sqrt(N[i]);
	  logjumpdens_cur[i] = logjumpdens_can;
	  ++accepts[0];
	}


    	


       //  sample beta0 and beta1, sample sigma20 sigma21 and update Xbeta0 Xbeta1

       
        
        XpY0 = t (MX) * wtheta0;
        XpY1 = t (MX) * wtheta1;

      beta0 = regress_beta_update2 (XpX, XpY0, Mb0, MB0, sigma20);
      sigma20 = (regress_sigma2_update2 (MX, wtheta0, beta0, nu0, delta0))/mean(N); 

      beta1 = regress_beta_update2 (XpX, XpY1, Mb1, MB1, sigma21);
      sigma21 = regress_sigma2_update2 (MX, wtheta1, beta1, nu1, delta1)/mean(N); 

        Xbeta0= MX *beta0;
        Xbeta1= MX *beta1;



 
// if after burnin store samples
	if ((iter >= burnin) && ((iter%thin)==0)){
	 p0mat(count,i) = p0[i];
         p1mat(count,i) = p1[i];
         sigma20m (count,0) = sigma20;
         sigma21m (count,0) = sigma21;
         for (int j = 0; j < k; j++)
        {
		beta0m (count, j) = beta0[j];
	        beta1m (count, j) = beta1[j];
      	  }
      }

}  // end tables loop 
      
if ((iter >= burnin) && ((iter%thin)==0)) ++count;


// print output to screen
      if (verbose==1 && (iter%25)==0){
	cout << "MCMChierEX iteration = " << iter <<  endl;
	cout << " MH acceptance rate = " << 
	  static_cast<double>(accepts[0]) / static_cast<double>(iter) / 
	  static_cast<double>(ntables) << endl << endl;
      }
      
    }
    

    // return sample
   Matrix<double> storagem = cbind(p0mat, p1mat);
    storagem = cbind (storagem, beta0m);
    storagem = cbind (storagem, beta1m);
    storagem = cbind (storagem, sigma20m);
    storagem = cbind (storagem, sigma21m);

   int loop = samrow[0] * samcol[0];
    for (int i = 0; i < loop; ++i)
      sample[i] = storagem[i];
  }

} // extern "C"

